{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ATIS data preprocessing\n",
    "\n",
    "read the `*.iob` files from `JointSLU` [github repo](https://github.com/yvchen/JointSLU/blob/master/program/BasicModel.py)\n",
    "\n",
    "following the `JointSLU` experiment, we train with `atis-2.train`, validate with `atis-2.dev` and test with `atis.test`\n",
    "\n",
    "split into tokenized sentences, entity lists, and intent list\n",
    "\n",
    "train a gensim w2v model on the training data\n",
    "\n",
    "encode the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/derek/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/home/derek/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/home/derek/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "from collections import Counter\n",
    "from preprocessing import CharacterIndexer, SlotIndexer, IntentIndexer\n",
    "from gensim.models import Word2Vec\n",
    "import json\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### read the file data\n",
    "\n",
    "the files are already preprocessed (lower-cased and stripped of most punctuation)\n",
    "\n",
    "the only extra step we apply is replacing every digit with `#`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def readatis(filename='data/atis/atis-2.train.w-intent.iob'):\n",
    "    \"\"\"\n",
    "    read one ATIS `*.w-intent.iob` split into sentences, slot tags and intents\n",
    "\n",
    "    each row of the file holds two tab-separated fields: the\n",
    "    whitespace-tokenized sentence, and one label per token in which the\n",
    "    last label is the intent of the sentence\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    filename : str\n",
    "        path to the tab-separated `.iob` file\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    sents : list of list of str\n",
    "        tokens of each sentence, with every digit 0-9 replaced by '#'\n",
    "    ners : list of list of str\n",
    "        one slot tag per token; the last position, which held the intent\n",
    "        label in the file, is set to the null tag 'O'\n",
    "    ints : list of str\n",
    "        intent label of each sentence\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    AssertionError\n",
    "        if a sentence and its label row differ in length\n",
    "    \"\"\"\n",
    "    data = pd.read_csv(filename, sep='\\t', header=None)\n",
    "    # mask digits with one translate per token; tokens hold no whitespace,\n",
    "    # so this equals masking the joined sentence and splitting it again\n",
    "    mask_digits = str.maketrans('0123456789', '#' * 10)\n",
    "    sents = [[tok.translate(mask_digits) for tok in s.split()]\n",
    "             for s in data[0].tolist()]\n",
    "    labels = [s.split() for s in data[1].tolist()]\n",
    "    # check lengths\n",
    "    assert len(sents) == len(labels)\n",
    "    # the intent label is the last item of each label row.\n",
    "    # remove it and replace it with a 'O' null tag\n",
    "    ints = [row[-1] for row in labels]\n",
    "    ners = [row[:-1] + ['O'] for row in labels]\n",
    "    # check sent, ner, int lengths and name the offending row on failure\n",
    "    assert len(sents) == len(ints)\n",
    "    for i, (sent, tags) in enumerate(zip(sents, ners)):\n",
    "        assert len(sent) == len(tags), (\n",
    "            'row {} of {}: {} tokens but {} labels'.format(\n",
    "                i, filename, len(sent), len(tags)))\n",
    "    return sents, ners, ints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "trn_texts, trn_slots, trn_ints = readatis('data/atis/atis-2.train.w-intent.iob')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "dev_texts, dev_slots, dev_ints = readatis('data/atis/atis-2.dev.w-intent.iob')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "tst_texts, tst_slots, tst_ints = readatis('data/atis/atis.test.w-intent.iob')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4478, 500, 893)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(trn_texts), len(dev_texts), len(tst_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### reduce slot names\n",
    "\n",
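    "optionally, we can reduce each slot tag to its macro-level name (the part before the first `.`); this step is currently disabled, so the cell below is commented out"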
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def slotsplitter(slotslist):\n",
    "#     newlist = []\n",
    "#     for s in slotslist:\n",
    "#         newsent = [i.split('.')[0] for i in s]\n",
    "#         newlist.append(newsent)\n",
    "#     return newlist\n",
    "\n",
    "# trn_slots = slotsplitter(trn_slots)\n",
    "# dev_slots = slotsplitter(dev_slots)\n",
    "# tst_slots = slotsplitter(tst_slots)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### visual test of data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "120"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of distinct slot tags seen in the training data\n",
    "len({tag for tags in trn_slots for tag in tags})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('O', 41022),\n",
       " ('B-toloc.city_name', 3919),\n",
       " ('B-fromloc.city_name', 3892),\n",
       " ('I-toloc.city_name', 987),\n",
       " ('B-depart_date.day_name', 785),\n",
       " ('B-airline_name', 639),\n",
       " ('I-fromloc.city_name', 632),\n",
       " ('B-depart_time.period_of_day', 521),\n",
       " ('I-airline_name', 379),\n",
       " ('B-depart_date.day_number', 355)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Counter([t for s in trn_slots for t in s]).most_common(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of distinct intent labels seen in the training data\n",
    "len(set(trn_ints))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "13.276686020544886"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "slens = [len(s) for s in trn_texts]\n",
    "np.mean(slens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "95.73470299240732"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "swuts = [1 if l < 22 else 0 for l in slens]\n",
    "sum(swuts)*100/len(swuts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('atis_flight', 3309),\n",
       " ('atis_airfare', 385),\n",
       " ('atis_ground_service', 230),\n",
       " ('atis_airline', 139),\n",
       " ('atis_abbreviation', 130),\n",
       " ('atis_aircraft', 70),\n",
       " ('atis_flight_time', 45),\n",
       " ('atis_quantity', 41),\n",
       " ('atis_flight#atis_airfare', 19),\n",
       " ('atis_city', 18)]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the hashtagged elements seem to be for multi-label predictions\n",
    "# we will treat them as a separate joint label for now\n",
    "Counter(trn_ints).most_common(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "txt: 12 ['BOS', 'is', 'there', 'a', 'delta', 'flight', 'from', 'denver', 'to', 'san', 'francisco', 'EOS']\n",
      "ent: 12 ['O', 'O', 'O', 'O', 'B-airline_name', 'O', 'O', 'B-fromloc.city_name', 'O', 'B-toloc.city_name', 'I-toloc.city_name', 'O']\n",
      "int: atis_flight\n",
      "\n",
      "txt: 14 ['BOS', \"i'd\", 'like', 'a', 'twa', 'flight', 'from', 'las', 'vegas', 'to', 'new', 'york', 'nonstop', 'EOS']\n",
      "ent: 14 ['O', 'O', 'O', 'O', 'B-airline_code', 'O', 'O', 'B-fromloc.city_name', 'I-fromloc.city_name', 'O', 'B-toloc.city_name', 'I-toloc.city_name', 'B-flight_stop', 'O']\n",
      "int: atis_flight\n",
      "\n",
      "txt: 12 ['BOS', 'tell', 'me', 'about', 'ground', 'transportation', 'between', 'orlando', 'international', 'and', 'orlando', 'EOS']\n",
      "ent: 12 ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-fromloc.airport_name', 'I-fromloc.airport_name', 'O', 'B-toloc.city_name', 'O']\n",
      "int: atis_ground_service\n",
      "\n",
      "txt: 23 ['BOS', 'what', 'are', 'the', 'nonstop', 'flights', 'on', 'america', 'west', 'or', 'southwest', 'air', 'from', 'kansas', 'city', 'to', 'burbank', 'on', 'saturday', 'may', 'twenty', 'two', 'EOS']\n",
      "ent: 23 ['O', 'O', 'O', 'O', 'B-flight_stop', 'O', 'O', 'B-airline_name', 'I-airline_name', 'B-or', 'B-airline_name', 'O', 'O', 'B-fromloc.city_name', 'I-fromloc.city_name', 'O', 'B-toloc.city_name', 'O', 'B-arrive_date.day_name', 'B-arrive_date.month_name', 'B-arrive_date.day_number', 'I-arrive_date.day_number', 'O']\n",
      "int: atis_flight\n",
      "\n",
      "txt: 13 ['BOS', \"what's\", 'the', 'first', 'flight', 'after', '#', 'pm', 'leaving', 'washington', 'to', 'denver', 'EOS']\n",
      "ent: 13 ['O', 'O', 'O', 'B-flight_mod', 'O', 'B-depart_time.time_relative', 'B-depart_time.time', 'I-depart_time.time', 'O', 'B-fromloc.city_name', 'O', 'B-toloc.city_name', 'O']\n",
      "int: atis_flight\n",
      "\n"
     ]
    }
   ],
   "source": [
    "for i in range(1, 6):\n",
    "    print('txt:', len(trn_texts[-i]), trn_texts[-i])\n",
    "    print('ent:', len(trn_slots[-i]), trn_slots[-i])\n",
    "    print('int:', trn_ints[-i])\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### train word2vec embeddings\n",
    "\n",
    "using pretrained embeddings has been found to increase performance in various NLP tasks.\n",
    "\n",
    "we could train on an external corpus such as brown but we will just use the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# first, remove the BOS and EOS tags from the training sentences\n",
    "w2v_text = [s[1:-1] for s in trn_texts]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training done!\n"
     ]
    }
   ],
   "source": [
    "# train and save model\n",
    "# fixed seed + a single worker thread so the embeddings can be reproduced;\n",
    "# gensim additionally needs PYTHONHASHSEED set before the kernel starts\n",
    "# to get identical vectors across interpreter launches\n",
    "model = Word2Vec(w2v_text, size=200, min_count=1, window=5, workers=1, iter=5, seed=42)\n",
    "model.save('model/atis_w2v.gensimmodel')\n",
    "print('training done!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# map every word in the model vocabulary to its integer index\n",
    "vocab = {word: entry.index for word, entry in model.wv.vocab.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('airline', 0.9996054172515869),\n",
       " ('september', 0.9995239973068237),\n",
       " ('into', 0.9994678497314453),\n",
       " ('flying', 0.9994544982910156),\n",
       " ('eastern', 0.9994508028030396),\n",
       " ('has', 0.9994401931762695),\n",
       " ('cities', 0.9994279146194458),\n",
       " ('other', 0.999421238899231),\n",
       " ('stopping', 0.9994186758995056),\n",
       " ('out', 0.999413013458252)]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test\n",
    "model.wv.most_similar('delta')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('for', 0.9986646175384521),\n",
       " ('economy', 0.9985247850418091),\n",
       " ('schedule', 0.9984636902809143),\n",
       " ('service', 0.9984104037284851),\n",
       " ('flying', 0.9983426332473755),\n",
       " ('december', 0.9983159899711609),\n",
       " ('only', 0.9983118772506714),\n",
       " ('times', 0.9983005523681641),\n",
       " ('via', 0.9982900619506836),\n",
       " ('any', 0.9982751607894897)]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.wv.most_similar('dallas')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## index the sentences\n",
    "\n",
    "we will use premade classes for this that will integer-index and pad sentences to fixed lengths.\n",
    "\n",
    "we will truncate max sentence and word lengths to mean + 2*st.dev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fit(): splitting...\n",
      "fit(): max sent len set to 22\n",
      "fit(): max word len set to 9\n",
      "fit(): creating conversion dictionaries...\n",
      "fit(): tru word vocab: 728\n",
      "fit(): tru char vocab: 35\n",
      "fit(): done!\n"
     ]
    }
   ],
   "source": [
    "# instantiate a sentence indexer and fit to the training data\n",
    "sentindexer = CharacterIndexer(max_sent_mode='std')\n",
    "sentindexer.fit(trn_texts, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4478, 22), (500, 22), (893, 22), (4478, 22, 9))"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transform the sentence data\n",
    "trn_text_idx, trn_char_idx = sentindexer.transform(trn_texts)\n",
    "dev_text_idx, dev_char_idx = sentindexer.transform(dev_texts)\n",
    "tst_text_idx, tst_char_idx = sentindexer.transform(tst_texts)\n",
    "trn_text_idx.shape, dev_text_idx.shape, tst_text_idx.shape, trn_char_idx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fit(): labels set to size: 121\n"
     ]
    }
   ],
   "source": [
    "# instantiate a slot indexer and fit to the training data\n",
    "slotindexer = SlotIndexer(max_len=sentindexer.max_sent_len)\n",
    "slotindexer.fit(trn_slots, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4478, 22, 1), (500, 22, 1), (893, 22, 1))"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transform the slot data\n",
    "trn_slot_idx = slotindexer.transform(trn_slots)\n",
    "dev_slot_idx = slotindexer.transform(dev_slots)\n",
    "tst_slot_idx = slotindexer.transform(tst_slots)\n",
    "trn_slot_idx.shape, dev_slot_idx.shape, tst_slot_idx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fit(): labels set to size: 22\n"
     ]
    }
   ],
   "source": [
    "intindexer = IntentIndexer()\n",
    "intindexer.fit(trn_ints, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((4478, 1), (500, 1), (893, 1))"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transform the intent data\n",
    "trn_int_idx = intindexer.transform(trn_ints)\n",
    "dev_int_idx = intindexer.transform(dev_ints)\n",
    "tst_int_idx = intindexer.transform(tst_ints)\n",
    "trn_int_idx.shape, dev_int_idx.shape, tst_int_idx.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### test for NaNs, out-of-bounds indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n",
      "[False]\n"
     ]
    }
   ],
   "source": [
    "# each encoded array should print [False], i.e. it holds no NaN\n",
    "encoded_arrays = (\n",
    "    trn_text_idx, trn_char_idx,\n",
    "    dev_text_idx, dev_char_idx,\n",
    "    tst_text_idx, tst_char_idx,\n",
    "    trn_slot_idx, dev_slot_idx, tst_slot_idx,\n",
    "    trn_int_idx, dev_int_idx, tst_int_idx,\n",
    ")\n",
    "for encoded in encoded_arrays:\n",
    "    print(np.unique(np.isnan(encoded)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "728\n",
      "[726]\n",
      "[727]\n",
      "[727]\n",
      "35\n",
      "[33]\n",
      "[32]\n",
      "[33]\n",
      "121\n",
      "[119]\n",
      "[115]\n",
      "[115]\n",
      "22\n",
      "[21]\n",
      "[17]\n",
      "[16]\n"
     ]
    }
   ],
   "source": [
    "# for each indexer: print its vocabulary/label size, then the largest\n",
    "# index found in the train, dev and test arrays it produced\n",
    "bounds_checks = [\n",
    "    (sentindexer.max_word_vocab, (trn_text_idx, dev_text_idx, tst_text_idx)),\n",
    "    (sentindexer.max_char_vocab, (trn_char_idx, dev_char_idx, tst_char_idx)),\n",
    "    (slotindexer.labelsize, (trn_slot_idx, dev_slot_idx, tst_slot_idx)),\n",
    "    (intindexer.labelsize, (trn_int_idx, dev_int_idx, tst_int_idx)),\n",
    "]\n",
    "for size, splits in bounds_checks:\n",
    "    print(size)\n",
    "    for split in splits:\n",
    "        print(np.unique(np.max(split)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### test decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['BOS',\n",
       "  'i',\n",
       "  'would',\n",
       "  'like',\n",
       "  'to',\n",
       "  'find',\n",
       "  'a',\n",
       "  'flight',\n",
       "  'from',\n",
       "  'charlotte',\n",
       "  'to',\n",
       "  'las',\n",
       "  'vegas',\n",
       "  'that',\n",
       "  'makes',\n",
       "  'a',\n",
       "  'stop',\n",
       "  'in',\n",
       "  'st.',\n",
       "  'louis',\n",
       "  'EOS']]"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sentindexer.inverse_transform(tst_text_idx[0:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[['O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'B-fromloc.city_name',\n",
       "  'O',\n",
       "  'B-toloc.city_name',\n",
       "  'I-toloc.city_name',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'O',\n",
       "  'B-stoploc.city_name',\n",
       "  'I-stoploc.city_name',\n",
       "  'O']]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "slotindexer.inverse_transform(tst_slot_idx[0:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['atis_flight', 'atis_airfare', 'atis_flight', 'atis_flight', 'atis_flight']"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "intindexer.inverse_transform(tst_int_idx[0:5])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### save all items"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_pickle(obj, path):\n",
    "    \"\"\"pickle `obj` to `path`, closing the file even if the dump fails\"\"\"\n",
    "    with open(path, 'wb') as handle:\n",
    "        pickle.dump(obj, handle)\n",
    "\n",
    "# save transformers\n",
    "save_pickle(sentindexer, 'encoded/atis_sentindexer.pkl')\n",
    "save_pickle(slotindexer, 'encoded/atis_slotindexer.pkl')\n",
    "save_pickle(intindexer, 'encoded/atis_intindexer.pkl')\n",
    "\n",
    "# save word2vec vocab\n",
    "save_pickle(vocab, 'model/atis_w2v_vocab.pkl')\n",
    "\n",
    "# save text data\n",
    "save_pickle(trn_texts, 'encoded/trn_texts_raw.pkl')\n",
    "save_pickle(dev_texts, 'encoded/dev_texts_raw.pkl')\n",
    "save_pickle(tst_texts, 'encoded/tst_texts_raw.pkl')\n",
    "\n",
    "save_pickle(trn_slots, 'encoded/trn_slots_raw.pkl')\n",
    "save_pickle(dev_slots, 'encoded/dev_slots_raw.pkl')\n",
    "save_pickle(tst_slots, 'encoded/tst_slots_raw.pkl')\n",
    "\n",
    "save_pickle(trn_ints, 'encoded/trn_ints_raw.pkl')\n",
    "save_pickle(dev_ints, 'encoded/dev_ints_raw.pkl')\n",
    "save_pickle(tst_ints, 'encoded/tst_ints_raw.pkl')\n",
    "\n",
    "# save encoded data\n",
    "np.save('encoded/trn_text_idx.npy', trn_text_idx)\n",
    "np.save('encoded/dev_text_idx.npy', dev_text_idx)\n",
    "np.save('encoded/tst_text_idx.npy', tst_text_idx)\n",
    "\n",
    "np.save('encoded/trn_char_idx.npy', trn_char_idx)\n",
    "np.save('encoded/dev_char_idx.npy', dev_char_idx)\n",
    "np.save('encoded/tst_char_idx.npy', tst_char_idx)\n",
    "\n",
    "np.save('encoded/trn_slot_idx.npy', trn_slot_idx)\n",
    "np.save('encoded/dev_slot_idx.npy', dev_slot_idx)\n",
    "np.save('encoded/tst_slot_idx.npy', tst_slot_idx)\n",
    "\n",
    "np.save('encoded/trn_int_idx.npy', trn_int_idx)\n",
    "np.save('encoded/dev_int_idx.npy', dev_int_idx)\n",
    "np.save('encoded/tst_int_idx.npy', tst_int_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "base"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
